from .attn_cls_model import AttnCls

# Registry mapping the architecture names accepted by ``create_model`` to the
# class (or other callable) that builds that architecture. Register new
# architectures by adding an entry here.
__factory = {'attn_cls': AttnCls}


def create_model(arch_name, *args, **kwargs):
    """Instantiate a registered model architecture by name.

    Args:
        arch_name: Key of the architecture in the module-level ``__factory``
            registry (e.g. ``'attn_cls'``).
        *args: Positional arguments forwarded unchanged to the registered
            constructor.
        **kwargs: Keyword arguments forwarded unchanged to the registered
            constructor.

    Returns:
        Whatever the registered constructor returns for these arguments.

    Raises:
        KeyError: If ``arch_name`` is not registered. The message names the
            requested architecture and lists the registered ones.
    """
    # Keep the ``try`` limited to the registry lookup so that a KeyError
    # raised inside a model's constructor is not misreported as an unknown
    # architecture name.
    try:
        model_cls = __factory[arch_name]
    except KeyError:
        available = ', '.join(sorted(__factory))
        # Same exception type as before, so callers catching KeyError still
        # work; ``from None`` drops the redundant chained lookup traceback.
        raise KeyError(
            'Unknown architecture {!r}; available architectures: {}'.format(
                arch_name, available)
        ) from None
    return model_cls(*args, **kwargs)


if __name__ == '__main__':
    # Manual smoke test: build the 'attn_cls' model with the demo settings
    # below and print its representation.
    demo_kwargs = {
        'num_classes': 11,
        'convs_depth': 101,
        'convs_pretrained': True,
    }
    print(create_model('attn_cls', **demo_kwargs))
